Xtoken/off policy distillation gh#2245
Open
avenkateshha wants to merge 17 commits intoNVIDIA-NeMo:mainfrom
Open
Xtoken/off policy distillation gh#2245avenkateshha wants to merge 17 commits intoNVIDIA-NeMo:mainfrom
avenkateshha wants to merge 17 commits intoNVIDIA-NeMo:mainfrom
Conversation
added 17 commits
April 9, 2026 17:46
- Off-policy distillation pipeline (teacher Llama-3.1-8B, student Llama-3.2-1B) with arrow dataset support and inline MATH/MMLU generation-based evaluation - Compute distributed log_softmax before top-k for correct KL divergence - Add CUDA IPC buffer mechanism to avoid Ray object store bottleneck for large top-k logprob tensors (based on dtensor_sharath.py approach) - Update loss function to skip re-normalization of teacher log probabilities - Add submit scripts, configs, and eval benchmarks Made-with: Cursor
Made-with: Cursor
- Flatten teacher IPC data structure and use mb_idx*mbs indexing instead of cumulative mb_offset/mb_size for microbatch slicing - Add log_softmax for teacher top-k logits in standard (non-IPC) path - Restore output_replicated for teacher path and add kl_loss/nll_loss to aggregated results - Make arrow config self-contained with explicit settings Made-with: Cursor
Made-with: Cursor
Refactor teacher logit sharing to use per-microbatch IPC buffers instead of accumulating all teacher logits post-loop. Update loss functions to handle optional microbatch indexing. Bump config to TP=4 and 10k steps. Made-with: Cursor
Made-with: Cursor
- Add `use_ipc` config flag to switch between IPC (in-process communication) and non-IPC (data-dict) teacher logprob paths - Simplify KL loss to use (k+1)-dim distributions with a "rest" bucket, unifying IPC and non-IPC code paths - Branch teacher inference and student training for both train and validation loops based on use_ipc setting - Update submit script with IPC test experiment config Made-with: Cursor
- Add x_token/ module with TokenAligner for cross-vocabulary distillation - Rewrite CrossTokenizerDistillationLossFn with chunk-averaged KL that handles 1:1, 1:many, many:1, and many:many token alignments, matching the original TokenAligner.compute_KL_loss_optimized() exactly (verified via sanity check: 0.00% difference) - Fix teacher IPC to send full logits (topk_logits=None) instead of topk_logits=0 which produced empty tensors - Pass global_top_indices through to projection for memory optimization (2.3GB -> 125MB for projection output tensor) - Add cross-tokenizer data processing in training loop (dual tokenization, alignment, teacher data dict) - Add unbuffered stdout for better SLURM log visibility - Add example config and sanity check script Made-with: Cursor
Documents architecture, all new/modified files, usage instructions, configuration reference, and design decisions for the cross-tokenizer off-policy distillation feature built on NeMo RL v0.5.0. Made-with: Cursor
…compat - Switch default teacher from Qwen3-8B to Phi-4-mini-instruct - Add gold loss (common-vocab KL + uncommon-vocab L1) and xtoken loss modes - Add CE loss with dynamic loss scaling option - Replace dense projection with CSR sparse matmul for memory efficiency - Add MMLU 5-shot evaluation benchmark - Fix NotRequired import for Python <3.11 compatibility (16 files) - Add submit_cross_tokenizer.sh sbatch script - Add sanity check script for alignment and loss verification - Update LR schedule for 80k step training (warmup 4k, cosine 76k) Made-with: Cursor
Made-with: Cursor
…parse projection - Cache CrossTokenizerDistillationLossFn on policy workers at init and pass None to train() calls, eliminating repeated Ray serialization of the loss function (which includes large sparse matrices) each step. - Add set_loss_fn() and update_cross_tokenizer_data() to Policy and DTensorPolicyWorkerV2 to support per-step cross-tokenizer data updates. - Optimize sparse token projection by pre-reducing the sparse matrix with index_select before projection instead of projecting full vocab and slicing afterward. - Use AutoConfig.from_pretrained() for vocab sizes in sanity check script. Made-with: Cursor
…P rank Reduced training time with this optimization. Avoid Ray serialization of the loss function by having each worker construct CrossTokenizerDistillationLossFn from config + shared filesystem. Shard teacher_input_ids and aligned_pairs per data-parallel rank instead of broadcasting the full batch to every worker. Made-with: Cursor
- Add O(n+m) character-offset alignment via two-pointer walk on tokenizer offset mappings, with automatic DP fallback for failed samples - Precompute canonical token ID maps at startup to skip convert_ids_to_tokens - Add Numba JIT-accelerated DP kernel and banded DP variant - Add KD preprocessor preserving raw text for teacher tokenization - Add numba dependency - Update config: expand arrow data glob, set max_num_epochs=1 - Update submit script: bump max_num_steps=10, rename experiment to raw-text-kd-16node Made-with: Cursor
Introduce CUDA kernel and Python integration module for faster TokenAligner dynamic programming base-case computation. Made-with: Cursor
… preprocessing with current-step GPU training while keeping alignment behavior unchanged. Add explicit token-aligner runtime switches (`use_char_offset`, `use_align_fast`, CUDA-DP toggles), clean up dead/duplicated paths, and simplify the step orchestration with typed prefetch payloads and helper extraction for maintainability. Made-with: Cursor
Set explicit token aligner defaults and document the total-GPUs/2 heuristic for cross_tokenizer_num_workers so large-batch off-policy runs can iterate on stable, reproducible CT pool sizing. Made-with: Cursor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use thisBefore your PR is "Ready for review"
Pre checks:
Additional Information